# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Optional, Union

import numpy as np

from sisl import Lattice, LatticeChild
from sisl._internal import set_module
from sisl.messages import deprecate_argument
from sisl.physics.brillouinzone import BrillouinZone
from sisl.unit.siesta import unit_convert

from ..sile import add_sile, sile_fh_open, sile_raise_write
from .sile import SileSiesta

# Names exported by this module
__all__ = ["kpSileSiesta", "rkpSileSiesta"]

# Conversion factor from Bohr to Ang; used below to convert k-points between
# the 1/Bohr units of the file and 1/Ang
Bohr2Ang = unit_convert("Bohr", "Ang")

# Annotation alias for an optional lattice argument: None, a Lattice or a LatticeChild
TLattice = Optional[Union[Lattice, LatticeChild]]


@set_module("sisl.io.siesta")
class kpSileSiesta(SileSiesta):
    """k-points file in 1/Bohr units"""

    @sile_fh_open()
    @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", "0.15", "0.16")
    def read_data(self, lattice: TLattice = None):
        """Returns K-points and their weights from the file

        The k-points are stored in the file in units of 1/Bohr, they are
        converted to 1/Ang when read.

        Parameters
        ----------
        lattice :
           if supplied the returned k-points will be in reduced coordinates

        Returns
        -------
        k : k-points, in units 1/Ang (reduced coordinates if `lattice` is supplied)
        w : weights for k-points
        """
        # First line holds the number of k-points
        nk = int(self.readline())

        k = np.empty([nk, 3], np.float64)
        w = np.empty([nk], np.float64)
        for ik in range(nk):
            # Each line reads: <index> <kx> <ky> <kz> <weight>
            l = self.readline().split()
            k[ik, :] = float(l[1]), float(l[2]), float(l[3])
            w[ik] = float(l[4])

        # Correct units to 1/Ang
        k /= Bohr2Ang

        if lattice is None:
            return k, w
        # Project the Cartesian k-points onto the lattice vectors to
        # obtain reduced coordinates
        return np.dot(k, lattice.cell.T / (2 * np.pi)), w

    @sile_fh_open()
    def write_data(self, k, weight, fmt: str = ".9e"):
        """Writes K-points to file

        The values are written as passed, no unit conversion is done.

        Parameters
        ----------
        k : array_like
           k-points in units 1/Bohr, a single k-point of length 3 is also accepted
        weight : array_like
           same length as k, weights of k-points
        fmt : str, optional
           format for the k-values

        Raises
        ------
        ValueError
           if the number of weights does not match the number of k-points
        """
        sile_raise_write(self)

        k = np.asarray(k)
        if k.ndim < 2:
            # A single k-point (or an empty list); ensure one row per k-point
            # so the count in the header matches the number of written lines
            k = k.reshape(-1, 3)
        weight = np.asarray(weight).ravel()

        nk = len(k)
        if len(weight) != nk:
            # zip would silently truncate, leaving a header that disagrees
            # with the number of rows
            raise ValueError(
                f"{self.__class__.__name__}.write_data requires as many weights "
                f"as k-points, got {len(weight)} weights and {nk} k-points."
            )

        self._write(f"{nk}\n")
        _fmt = ("{:d}" + (" {:" + fmt + "}") * 4) + "\n"

        # The k-point index in the file is 1-based
        for i, (kk, w) in enumerate(zip(k, weight), 1):
            self._write(_fmt.format(i, kk[0], kk[1], kk[2], w))

    @sile_fh_open()
    @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", "0.15", "0.16")
    def read_brillouinzone(self, lattice: TLattice) -> BrillouinZone:
        """Returns the K-points of the file as a BrillouinZone object

        Parameters
        ----------
        lattice : LatticeChild
           required lattice for the BrillouinZone object, also used to
           convert the k-points to reduced coordinates

        Returns
        -------
        bz : BrillouinZone
        """
        k, w = self.read_data(lattice)

        bz = BrillouinZone(lattice)
        # Overwrite the k-points and weights with those read from the file
        bz._k = k
        bz._w = w
        return bz

    @sile_fh_open()
    def write_brillouinzone(self, bz: BrillouinZone, fmt: str = ".9e"):
        """Writes BrillouinZone-points to file

        Parameters
        ----------
        bz : BrillouinZone
           object contain all weights and k-points
        fmt : str, optional
           format for the k-values
        """
        # And convert to 1/Bohr
        k = bz.tocartesian(bz.k) * Bohr2Ang
        self.write_data(k, bz.weight, fmt)


@set_module("sisl.io.siesta")
class rkpSileSiesta(kpSileSiesta):
    """k-point file with the k-points given in units of the reciprocal lattice vectors

    It is mainly used as input for the kgrid.File fdf-option, for which this
    file provides the k-points in the correct format.
    """

    @sile_fh_open()
    def read_data(self):
        """Reads the k-points and their weights from the file

        No unit conversion is done, the values are returned as found in the file.

        Returns
        -------
        k : k-points, in units of the reciprocal lattice vectors
        w : weights for k-points
        """
        # The header line holds the number of k-points
        n_points = int(self.readline())

        kpts = np.empty([n_points, 3], np.float64)
        weights = np.empty([n_points], np.float64)
        for row in range(n_points):
            # Columns are: <index> <k1> <k2> <k3> <weight>
            fields = self.readline().split()
            kpts[row, :] = tuple(float(fields[col]) for col in (1, 2, 3))
            weights[row] = float(fields[4])

        return kpts, weights

    @sile_fh_open()
    @deprecate_argument("sc", "lattice", "use lattice= instead of sc=", "0.15", "0.16")
    def read_brillouinzone(self, lattice: TLattice) -> BrillouinZone:
        """Reads the k-points of the file into a BrillouinZone object

        Parameters
        ----------
        lattice : LatticeChild
           lattice passed to the BrillouinZone object

        Returns
        -------
        bz : BrillouinZone
        """
        kpts, weights = self.read_data()
        from sisl.physics.brillouinzone import BrillouinZone

        # Create the object and replace its k-points and weights
        # by the content of the file
        bz = BrillouinZone(lattice)
        bz._k, bz._w = kpts, weights
        return bz

    @sile_fh_open()
    def write_brillouinzone(self, bz: BrillouinZone, fmt: str = ".9e"):
        """Writes the k-points and weights of a BrillouinZone object to the file

        Parameters
        ----------
        bz : BrillouinZone
           object holding the k-points and weights to write
        fmt : str, optional
           format for the k-values
        """
        # bz.k is passed on unchanged, no conversion to Cartesian coordinates
        self.write_data(bz.k, bz.weight, fmt)


# Register both sile classes via add_sile under the "KP" and "RKP" keys
# NOTE(review): presumably these are the file endings and gzip=True also
# registers the gzipped variants -- confirm against add_sile
add_sile("KP", kpSileSiesta, gzip=True)
add_sile("RKP", rkpSileSiesta, gzip=True)
